{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f5499b54",
   "metadata": {},
   "source": [
    "\n",
    "# Similar Questions Retrieval\n",
    "\n",
    "This notebook is inspired by the [similar search example of Sentence-Transformers](https://www.sbert.net/examples/applications/semantic-search/README.html#similar-questions-retrieval), and adapted to support [RAFT ANN](https://github.com/rapidsai/raft) algorithm.\n",
    "\n",
    "The model was pre-trained on the [Natural Questions dataset](https://ai.google.com/research/NaturalQuestions). It consists of about 100k real Google search queries, together with an annotated passage from Wikipedia that provides the answer. It is an example of an asymmetric search task. As corpus, we use the smaller [Simple English Wikipedia](http://sbert.net/datasets/simplewiki-2020-11-01.jsonl.gz) so that it fits easily into memory.\n",
    "\n",
    "The steps to install the latest stable `pylibraft` package are available in the [documentation](https://docs.rapids.ai/api/raft/stable/build)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e8d55ede",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: sentence_transformers in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (2.2.2)\n",
      "Requirement already satisfied: torch in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (2.0.1)\n",
      "Requirement already satisfied: transformers<5.0.0,>=4.6.0 in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from sentence_transformers) (4.31.0)\n",
      "Requirement already satisfied: tqdm in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from sentence_transformers) (4.65.0)\n",
      "Requirement already satisfied: torchvision in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from sentence_transformers) (0.15.2)\n",
      "Requirement already satisfied: numpy in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from sentence_transformers) (1.24.4)\n",
      "Requirement already satisfied: scikit-learn in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from sentence_transformers) (1.3.0)\n",
      "Requirement already satisfied: scipy in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from sentence_transformers) (1.11.1)\n",
      "Requirement already satisfied: nltk in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from sentence_transformers) (3.8.1)\n",
      "Requirement already satisfied: sentencepiece in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from sentence_transformers) (0.1.99)\n",
      "Requirement already satisfied: huggingface-hub>=0.4.0 in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from sentence_transformers) (0.16.4)\n",
      "Requirement already satisfied: filelock in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from torch) (3.12.2)\n",
      "Requirement already satisfied: typing-extensions in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from torch) (4.7.1)\n",
      "Requirement already satisfied: sympy in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from torch) (1.12)\n",
      "Requirement already satisfied: networkx in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from torch) (3.1)\n",
      "Requirement already satisfied: jinja2 in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from torch) (3.1.2)\n",
      "Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from torch) (11.7.99)\n",
      "Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from torch) (11.7.99)\n",
      "Requirement already satisfied: nvidia-cuda-cupti-cu11==11.7.101 in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from torch) (11.7.101)\n",
      "Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from torch) (8.5.0.96)\n",
      "Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from torch) (11.10.3.66)\n",
      "Requirement already satisfied: nvidia-cufft-cu11==10.9.0.58 in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from torch) (10.9.0.58)\n",
      "Requirement already satisfied: nvidia-curand-cu11==10.2.10.91 in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from torch) (10.2.10.91)\n",
      "Requirement already satisfied: nvidia-cusolver-cu11==11.4.0.1 in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from torch) (11.4.0.1)\n",
      "Requirement already satisfied: nvidia-cusparse-cu11==11.7.4.91 in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from torch) (11.7.4.91)\n",
      "Requirement already satisfied: nvidia-nccl-cu11==2.14.3 in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from torch) (2.14.3)\n",
      "Requirement already satisfied: nvidia-nvtx-cu11==11.7.91 in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from torch) (11.7.91)\n",
      "Requirement already satisfied: triton==2.0.0 in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from torch) (2.0.0)\n",
      "Requirement already satisfied: setuptools in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch) (68.0.0)\n",
      "Requirement already satisfied: wheel in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch) (0.41.0)\n",
      "Requirement already satisfied: cmake in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from triton==2.0.0->torch) (3.27.0)\n",
      "Requirement already satisfied: lit in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from triton==2.0.0->torch) (16.0.6)\n",
      "Requirement already satisfied: fsspec in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from huggingface-hub>=0.4.0->sentence_transformers) (2023.6.0)\n",
      "Requirement already satisfied: requests in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from huggingface-hub>=0.4.0->sentence_transformers) (2.31.0)\n",
      "Requirement already satisfied: pyyaml>=5.1 in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from huggingface-hub>=0.4.0->sentence_transformers) (6.0)\n",
      "Requirement already satisfied: packaging>=20.9 in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from huggingface-hub>=0.4.0->sentence_transformers) (23.1)\n",
      "Requirement already satisfied: regex!=2019.12.17 in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from transformers<5.0.0,>=4.6.0->sentence_transformers) (2023.6.3)\n",
      "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from transformers<5.0.0,>=4.6.0->sentence_transformers) (0.13.3)\n",
      "Requirement already satisfied: safetensors>=0.3.1 in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from transformers<5.0.0,>=4.6.0->sentence_transformers) (0.3.1)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from jinja2->torch) (2.1.3)\n",
      "Requirement already satisfied: click in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from nltk->sentence_transformers) (8.1.6)\n",
      "Requirement already satisfied: joblib in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from nltk->sentence_transformers) (1.3.0)\n",
      "Requirement already satisfied: threadpoolctl>=2.0.0 in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from scikit-learn->sentence_transformers) (3.2.0)\n",
      "Requirement already satisfied: mpmath>=0.19 in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from sympy->torch) (1.3.0)\n",
      "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from torchvision->sentence_transformers) (10.0.0)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from requests->huggingface-hub>=0.4.0->sentence_transformers) (3.2.0)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from requests->huggingface-hub>=0.4.0->sentence_transformers) (3.4)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from requests->huggingface-hub>=0.4.0->sentence_transformers) (2.0.4)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /raid/danteg/miniconda3/envs/raft-ann/lib/python3.10/site-packages (from requests->huggingface-hub>=0.4.0->sentence_transformers) (2023.7.22)\n"
     ]
    }
   ],
   "source": [
    "!pip install sentence_transformers torch\n",
    "\n",
    "# Note: if you have a Hopper based GPU, like an H100, use these to install:\n",
    "# pip install torch --index-url https://download.pytorch.org/whl/cu118\n",
    "# pip install sentence_transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "eb1e81c3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mon Jul 31 14:35:31 2023       \n",
      "+-----------------------------------------------------------------------------+\n",
      "| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |\n",
      "|-------------------------------+----------------------+----------------------+\n",
      "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
      "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
      "|                               |                      |               MIG M. |\n",
      "|===============================+======================+======================|\n",
      "|   0  NVIDIA H100 80G...  On   | 00000000:1B:00.0 Off |                    0 |\n",
      "| N/A   30C    P0    75W / 700W |      0MiB / 81559MiB |      0%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   1  NVIDIA H100 80G...  On   | 00000000:43:00.0 Off |                    0 |\n",
      "| N/A   31C    P0    72W / 700W |      0MiB / 81559MiB |      0%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   2  NVIDIA H100 80G...  On   | 00000000:52:00.0 Off |                    0 |\n",
      "| N/A   34C    P0    70W / 700W |      0MiB / 81559MiB |      0%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   3  NVIDIA H100 80G...  On   | 00000000:61:00.0 Off |                    0 |\n",
      "| N/A   33C    P0    70W / 700W |      0MiB / 81559MiB |      0%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   4  NVIDIA H100 80G...  On   | 00000000:9D:00.0 Off |                    0 |\n",
      "| N/A   32C    P0    74W / 700W |      0MiB / 81559MiB |      0%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   5  NVIDIA H100 80G...  On   | 00000000:C3:00.0 Off |                    0 |\n",
      "| N/A   30C    P0    73W / 700W |      0MiB / 81559MiB |      0%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   6  NVIDIA H100 80G...  On   | 00000000:D1:00.0 Off |                    0 |\n",
      "| N/A   33C    P0    73W / 700W |      0MiB / 81559MiB |      0%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   7  NVIDIA H100 80G...  On   | 00000000:DF:00.0 Off |                    0 |\n",
      "| N/A   35C    P0    73W / 700W |      0MiB / 81559MiB |      0%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "                                                                               \n",
      "+-----------------------------------------------------------------------------+\n",
      "| Processes:                                                                  |\n",
      "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
      "|        ID   ID                                                   Usage      |\n",
      "|=============================================================================|\n",
      "|  No running processes found                                                 |\n",
      "+-----------------------------------------------------------------------------+\n"
     ]
    }
   ],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ee4c5cc0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/raid/danteg/miniconda3/envs/raftann/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "from sentence_transformers import SentenceTransformer, CrossEncoder, util\n",
    "import time\n",
    "import gzip\n",
    "import os\n",
    "import torch\n",
    "import pylibraft\n",
    "from pylibraft.neighbors import ivf_flat, ivf_pq\n",
    "pylibraft.config.set_output_as(lambda device_ndarray: device_ndarray.copy_to_host())\n",
    "\n",
    "if not torch.cuda.is_available():\n",
    "  print(\"Warning: No GPU found. Please add GPU to your notebook\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0a1a6307",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Passages: 509663\n"
     ]
    }
   ],
   "source": [
    "# We use the Bi-Encoder to encode all passages, so that we can use it with semantic search\n",
    "model_name = 'nq-distilbert-base-v1'\n",
    "bi_encoder = SentenceTransformer(model_name)\n",
    "\n",
    "# As dataset, we use Simple English Wikipedia. Compared to the full English wikipedia, it has only\n",
    "# about 170k articles. We split these articles into paragraphs and encode them with the bi-encoder\n",
    "\n",
    "wikipedia_filepath = 'data/simplewiki-2020-11-01.jsonl.gz'\n",
    "\n",
    "if not os.path.exists(wikipedia_filepath):\n",
    "    util.http_get('http://sbert.net/datasets/simplewiki-2020-11-01.jsonl.gz', wikipedia_filepath)\n",
    "\n",
    "passages = []\n",
    "with gzip.open(wikipedia_filepath, 'rt', encoding='utf8') as fIn:\n",
    "    for line in fIn:\n",
    "        data = json.loads(line.strip())\n",
    "        for paragraph in data['paragraphs']:\n",
    "            # We encode the passages as [title, text]\n",
    "            passages.append([data['title'], paragraph])\n",
    "\n",
    "# If you like, you can also limit the number of passages you want to use\n",
    "print(\"Passages:\", len(passages))\n",
    "\n",
    "# To speed things up, pre-computed embeddings are downloaded.\n",
    "# The provided file encoded the passages with the model 'nq-distilbert-base-v1'\n",
    "if model_name == 'nq-distilbert-base-v1':\n",
    "    embeddings_filepath = 'simplewiki-2020-11-01-nq-distilbert-base-v1.pt'\n",
    "    if not os.path.exists(embeddings_filepath):\n",
    "        util.http_get('http://sbert.net/datasets/simplewiki-2020-11-01-nq-distilbert-base-v1.pt', embeddings_filepath)\n",
    "\n",
    "    corpus_embeddings = torch.load(embeddings_filepath)\n",
    "    corpus_embeddings = corpus_embeddings.float()  # Convert embedding file to float\n",
    "    if torch.cuda.is_available():\n",
    "        corpus_embeddings = corpus_embeddings.to('cuda')\n",
    "else:  # Here, we compute the corpus_embeddings from scratch (which can take a while depending on the GPU)\n",
    "    corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f4e9b9d",
   "metadata": {},
   "source": [
    "# Vector Search using RAPIDS RAFT\n",
    "Now that our embeddings are ready to be indexed and that the model has been loaded, we can use RAPIDS RAFT to do our vector search.\n",
    "\n",
    "This is done in two step: First we build the index, then we search it.\n",
    "With `pylibraft` all you need is those four Python lines:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ad90b4be",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[W] [14:35:48.810785] [raft::ivf_pq::build] the default cuda resource is used for the raft workspace allocations. This may lead to a significant slowdown for this algorithm. Consider using the default pool resource (`raft::resource::set_workspace_to_pool_resource`) or set your own resource explicitly (`raft::resource::set_workspace_resource`).\n",
      "[W] [14:35:53.831753] [raft::ivf_pq::extend] the default cuda resource is used for the raft workspace allocations. This may lead to a significant slowdown for this algorithm. Consider using the default pool resource (`raft::resource::set_workspace_to_pool_resource`) or set your own resource explicitly (`raft::resource::set_workspace_resource`).\n",
      "CPU times: user 2.21 s, sys: 2.49 s, total: 4.7 s\n",
      "Wall time: 5.13 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "params = ivf_pq.IndexParams(n_lists=150, pq_dim=96)\n",
    "pq_index = ivf_pq.build(params, corpus_embeddings)\n",
    "search_params = ivf_pq.SearchParams()\n",
    "\n",
    "def search_raft_pq(query, top_k = 5):\n",
    "    # Encode the query using the bi-encoder and find potentially relevant passages\n",
    "    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)\n",
    "\n",
    "    hits = ivf_pq.search(search_params, pq_index, question_embedding[None], top_k)\n",
    "\n",
    "    # Output of top-k hits\n",
    "    print(\"Input question:\", query)\n",
    "    for k in range(top_k):\n",
    "        print(\"\\t{:.3f}\\t{}\".format(hits[0][0, k], passages[hits[1][0, k]]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07935bca",
   "metadata": {},
   "source": [
    "For IVF-PQ we want to reduce the memory footprint while keeping a good accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "724dcacb",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "IVF-PQ memory footprint: 373.3 MB\n",
      "Original dataset: 1493.2 MB\n",
      "Memory saved: 75.0%\n"
     ]
    }
   ],
   "source": [
    "pq_index_mem = pq_index.pq_dim * pq_index.size * pq_index.pq_bits\n",
    "print(\"IVF-PQ memory footprint: {:.1f} MB\".format(pq_index_mem / 2**20))\n",
    "\n",
    "original_mem = corpus_embeddings.shape[0] * corpus_embeddings.shape[1] * 4\n",
    "print(\"Original dataset: {:.1f} MB\".format(original_mem / 2**20))\n",
    "\n",
    "print(\"Memory saved: {:.1f}%\".format(100 * (1 - pq_index_mem / original_mem)))"
   ]
  },
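  {
   "cell_type": "markdown",
   "id": "3b7c9e21",
   "metadata": {},
   "source": [
    "The footprint arithmetic can also be written out by hand. This is a rough sketch under stated assumptions, not an exact accounting of the index: $N = 509{,}663$ is the number of passages printed earlier, $d = 768$ is the embedding dimension implied by the printed dataset size, $m = 96$ is `pq_dim`, and $b = 8$ is assumed for `pq_bits` (it is not set explicitly in this notebook). The product $N \\cdot m \\cdot b$ counts bits, so it is divided by 8 to get bytes.\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\text{original} &= N \\cdot d \\cdot 4 = 509{,}663 \\times 768 \\times 4 = 1{,}565{,}684{,}736 \\text{ bytes} \\approx 1493.2 \\text{ MiB} \\\\\n",
    "\\text{PQ codes} &= \\frac{N \\cdot m \\cdot b}{8} = \\frac{509{,}663 \\times 96 \\times 8}{8} = 48{,}927{,}648 \\text{ bytes} \\approx 46.7 \\text{ MiB} \\\\\n",
    "\\frac{\\text{PQ codes}}{\\text{original}} &= \\frac{m \\cdot b / 8}{4d} = \\frac{96}{3072} = \\frac{1}{32} \\approx 3.1\\%\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "Under these assumptions the encoded vectors need roughly 3% of the original memory, a saving of about 96.9%. This counts only the PQ codes; cluster centers, codebooks and row indices add some overhead, so the figure is best read as an estimate."
   ]
  },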
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c27d4715",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[W] [14:36:07.640223] [raft::ivf_pq::search] the default cuda resource is used for the raft workspace allocations. This may lead to a significant slowdown for this algorithm. Consider using the default pool resource (`raft::resource::set_workspace_to_pool_resource`) or set your own resource explicitly (`raft::resource::set_workspace_resource`).\n",
      "Input question: Who was Grace Hopper?\n",
      "\t190.855\t['Leona Helmsley', 'Leona Helmsley (July 4, 1920 – August 20, 2007) was an American businesswoman. She was known for having a flamboyant personality. She had a reputation for tyrannical behavior; she was nicknamed the Queen of Mean.']\n",
      "\t195.364\t['Grace Hopper', 'Hopper was born in New York, USA. Hopper graduated from Vassar College in 1928 and Yale University in 1934 with a Ph.D degree in mathematics. She joined the US Navy during the World War II in 1943. She worked on computers in the Navy for 43 years. She then worked in other private industry companies after 1949. She retired from the Navy in 1986 and died on January 1, 1992.']\n",
      "\t202.536\t['Anita Borg', 'Anita Borg (January 17, 1949 – April 6, 2003) was an American computer scientist. She founded the Institute for Women and Technology and the Grace Hopper Celebration of Women in Computing.']\n",
      "\t203.717\t['Brett Butler', 'Brett Butler (born January 30, 1958) is an American actress and stand-up comedian. She is best known for playing Grace in the sitcom \"Grace Under Fire\". She has also done other television programs and comedy acts.']\n",
      "\t203.991\t['Nellie Bly', 'Elizabeth Cochrane Seaman (born Elizabeth Jane Cochran; May 5, 1864 – January 27, 1922), better known by her pen name Nellie Bly, was an American journalist, novelist and inventor. She was a newspaper reporter, who worked at various jobs for exposing poor working conditions. Nellie Bly, also, fought for women\\'s right and was known for investigative reporting. She best known for her record-breaking trip around the world in 72 days, inspired by the adventure novel \"Around the World in Eighty Days\" by Jules Verne. In the 1880s, she went undercover as a mentally ill patient in a psychiatric hospital for ten days, with the report being made public in a book called \"\"Ten Days in a Mad-House\"\". She was added to the National Women\\'s Hall of Fame in 1998.']\n",
      "CPU times: user 98.3 ms, sys: 81.2 ms, total: 180 ms\n",
      "Wall time: 120 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "search_raft_pq(query=\"Who was Grace Hopper?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "bc375518",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input question: Who was Alan Turing?\n",
      "\t139.827\t['Alan Turing', 'Alan Mathison Turing OBE FRS (London, 23 June 1912 – Wilmslow, Cheshire, 7 June 1954) was an English mathematician and computer scientist. He was born in Maida Vale, London.']\n",
      "\t169.849\t['William Kahan', 'William Morton Kahan (born June 5, 1933) is a Canadian mathematician and computer scientist. He received the Turing Award in 1989 for \"\"his fundamental contributions to numerical analysis\".\" He was named an ACM Fellow in 1994, and added to the National Academy of Engineering in 2005.']\n",
      "\t177.520\t['Rolf Noskwith', 'Rolf Noskwith (19 June 1919 – 3 January 2017) was a British businessman. During the Second World War, he worked under Alan Turing as a cryptographer at the British military base Bletchley Park in Milton Keynes, Buckinghamshire.']\n",
      "\t179.202\t['Marvin Minsky', \"Marvin Lee Minsky (August 9, 1927 – January 24, 2016) was an American cognitive scientist in the field of artificial intelligence (AI). He was the co-founder of the Massachusetts Institute of Technology's AI laboratory, and author of several texts on AI and philosophy. He won the Turing Award in 1969.\"]\n",
      "\t179.819\t['Edsger W. Dijkstra', 'Edsger Wybe Dijkstra (May 11, 1930 – August 6, 2002; ) was a Dutch computer scientist. He received the 1972 Turing Award for fundamental contributions to developing programming languages, and was the Schlumberger Centennial Chair of Computer Sciences at The University of Texas at Austin from 1984 until 2000.']\n",
      "CPU times: user 4.89 ms, sys: 7.52 ms, total: 12.4 ms\n",
      "Wall time: 12 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "search_raft_pq(query=\"Who was Alan Turing?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ab154181",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input question: What is creating tides?\n",
      "\t125.037\t['Tide', \"A tide is the periodic rising and falling of Earth's ocean surface caused mainly by the gravitational pull of the Moon acting on the oceans. Tides cause changes in the depth of marine and estuarine (river mouth) waters. Tides also make oscillating currents known as tidal streams (~'rip tides'). This means that being able to predict the tide is important for coastal navigation. The strip of seashore that is under water at high tide and exposed at low tide, called the intertidal zone, is an important ecological product of ocean tides.\"]\n",
      "\t163.835\t['Tidal energy', \"Many things affect tides. The pull of the Moon is the largest effect, and most of the energy comes from the slowing of the Earth's spin.\"]\n",
      "\t167.368\t['Storm surge', 'A storm surge is a sudden rise of water hitting areas close to the coast. Storm surges are usually created by a hurricane or other tropical cyclone. The surge happens because a storm has fast winds and low atmospheric pressure. Water is pushed on shore, and the water level rises. Strong storm surges can flood coastal towns and destroy homes. A storm surge is considered the deadliest part of a hurricane. They kill many people each year.']\n",
      "\t177.143\t['Tidal force', 'Tidal force is caused by gravity and makes tides happen. This is because the gravitational field changes across the middle of a body (the diameter).']\n",
      "\t186.108\t['Tsunami', \"A tsunami is a natural disaster which is a series of fast-moving waves in the ocean caused by powerful earthquakes, volcanic eruptions, landslides, or simply an asteroid or a meteor crash inside the ocean. A tsunami has a very long wavelength. It can be hundreds of kilometers long. Usually, a tsunami starts suddenly. The waves travel at a great speed across an ocean with little energy loss. They can remove sand from beaches, destroy trees, toss and drag vehicles, houses and even destroy whole towns. Tsunamis can even be caused when a meteorite strikes the earth's surface, though it is very rare. A tsunami normally occurs in the Pacific Ocean, especially in what is called the ring of fire, but can occur in any large body of water.\"]\n",
      "CPU times: user 4.44 ms, sys: 4.65 ms, total: 9.09 ms\n",
      "Wall time: 12.4 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "search_raft_pq(query = \"What is creating tides?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "2d6017ed",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 208 ms, sys: 63.8 ms, total: 271 ms\n",
      "Wall time: 286 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "params = ivf_flat.IndexParams(n_lists=150)\n",
    "flat_index = ivf_flat.build(params, corpus_embeddings)\n",
    "search_params = ivf_flat.SearchParams()\n",
    "\n",
    "def search_raft_flat(query, top_k = 5):\n",
    "    # Encode the query using the bi-encoder and find potentially relevant passages\n",
    "    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)\n",
    "    \n",
    "    start_time = time.time()\n",
    "    hits = ivf_flat.search(search_params, flat_index, question_embedding[None], top_k)\n",
    "    end_time = time.time()\n",
    "\n",
    "    # Output of top-k hits\n",
    "    print(\"Input question:\", query)\n",
    "    print(\"Results (after {:.3f} seconds):\".format(end_time - start_time))\n",
    "    for k in range(top_k):\n",
    "        print(\"\\t{:.3f}\\t{}\".format(hits[0][0, k], passages[hits[1][0, k]]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f5cfb644",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input question: Who was Grace Hopper?\n",
      "Results (after 0.002 seconds):\n",
      "\t181.650\t['Grace Hopper', 'Hopper was born in New York, USA. Hopper graduated from Vassar College in 1928 and Yale University in 1934 with a Ph.D degree in mathematics. She joined the US Navy during the World War II in 1943. She worked on computers in the Navy for 43 years. She then worked in other private industry companies after 1949. She retired from the Navy in 1986 and died on January 1, 1992.']\n",
      "\t192.946\t['Leona Helmsley', 'Leona Helmsley (July 4, 1920 – August 20, 2007) was an American businesswoman. She was known for having a flamboyant personality. She had a reputation for tyrannical behavior; she was nicknamed the Queen of Mean.']\n",
      "\t194.951\t['Grace Hopper', 'Grace Murray Hopper (December 9 1906 – January 1 1992) was an American computer scientist and United States Navy officer.']\n",
      "\t202.192\t['Nellie Bly', 'Elizabeth Cochrane Seaman (born Elizabeth Jane Cochran; May 5, 1864 – January 27, 1922), better known by her pen name Nellie Bly, was an American journalist, novelist and inventor. She was a newspaper reporter, who worked at various jobs for exposing poor working conditions. Nellie Bly, also, fought for women\\'s right and was known for investigative reporting. She best known for her record-breaking trip around the world in 72 days, inspired by the adventure novel \"Around the World in Eighty Days\" by Jules Verne. In the 1880s, she went undercover as a mentally ill patient in a psychiatric hospital for ten days, with the report being made public in a book called \"\"Ten Days in a Mad-House\"\". She was added to the National Women\\'s Hall of Fame in 1998.']\n",
      "\t205.038\t['Abbie Hoffman', 'Abbot Howard \"Abbie\" Hoffman (November 30, 1936 – April 12, 1989) was an American social and political activist.']\n",
      "CPU times: user 6.48 ms, sys: 0 ns, total: 6.48 ms\n",
      "Wall time: 6.22 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "search_raft_flat(query=\"Who was Grace Hopper?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b5694d00",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input question: Who was Alan Turing?\n",
      "Results (after 0.002 seconds):\n",
      "\t106.131\t['Alan Turing', 'Alan Mathison Turing OBE FRS (London, 23 June 1912 – Wilmslow, Cheshire, 7 June 1954) was an English mathematician and computer scientist. He was born in Maida Vale, London.']\n",
      "\t158.646\t['William Kahan', 'William Morton Kahan (born June 5, 1933) is a Canadian mathematician and computer scientist. He received the Turing Award in 1989 for \"\"his fundamental contributions to numerical analysis\".\" He was named an ACM Fellow in 1994, and added to the National Academy of Engineering in 2005.']\n",
      "\t165.094\t['Alan Turing', 'A brilliant mathematician and cryptographer Alan was to become the founder of modern-day computer science and artificial intelligence; designing a machine at Bletchley Park to break secret Enigma encrypted messages used by the Nazi German war machine to protect sensitive commercial, diplomatic and military communications during World War 2. Thus, Turing made the single biggest contribution to the Allied victory in the war against Nazi Germany, possibly saving the lives of an estimated 2 million people, through his effort in shortening World War II.']\n",
      "\t167.321\t['Rolf Noskwith', 'Rolf Noskwith (19 June 1919 – 3 January 2017) was a British businessman. During the Second World War, he worked under Alan Turing as a cryptographer at the British military base Bletchley Park in Milton Keynes, Buckinghamshire.']\n",
      "\t176.480\t['Marvin Minsky', \"Marvin Lee Minsky (August 9, 1927 – January 24, 2016) was an American cognitive scientist in the field of artificial intelligence (AI). He was the co-founder of the Massachusetts Institute of Technology's AI laboratory, and author of several texts on AI and philosophy. He won the Turing Award in 1969.\"]\n",
      "CPU times: user 4.81 ms, sys: 1.19 ms, total: 6 ms\n",
      "Wall time: 6.06 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "search_raft_flat(query=\"Who was Alan Turing?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "fcfc3c5b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input question: What is creating tides?\n",
      "Results (after 0.002 seconds):\n",
      "\t94.909\t['Tide', \"A tide is the periodic rising and falling of Earth's ocean surface caused mainly by the gravitational pull of the Moon acting on the oceans. Tides cause changes in the depth of marine and estuarine (river mouth) waters. Tides also make oscillating currents known as tidal streams (~'rip tides'). This means that being able to predict the tide is important for coastal navigation. The strip of seashore that is under water at high tide and exposed at low tide, called the intertidal zone, is an important ecological product of ocean tides.\"]\n",
      "\t159.539\t['Tidal energy', \"Many things affect tides. The pull of the Moon is the largest effect, and most of the energy comes from the slowing of the Earth's spin.\"]\n",
      "\t159.740\t['Storm surge', 'A storm surge is a sudden rise of water hitting areas close to the coast. Storm surges are usually created by a hurricane or other tropical cyclone. The surge happens because a storm has fast winds and low atmospheric pressure. Water is pushed on shore, and the water level rises. Strong storm surges can flood coastal towns and destroy homes. A storm surge is considered the deadliest part of a hurricane. They kill many people each year.']\n",
      "\t178.283\t['Sea', 'Wind blowing over the surface of a body of water forms waves. The friction between air and water caused by a gentle breeze on a pond causes ripples to form. A strong blow over the ocean causes larger waves as the moving air pushes against the raised ridges of water. The waves reach their greatest height when the rate at which they travel nearly matches the speed of the wind. The waves form at right angles to the direction from which the wind blows. In open water, if the wind continues to blow, as happens in the Roaring Forties in the southern hemisphere, long, organized masses of water called swell roll across the ocean. If the wind dies down, the wave formation is reduced but waves already formed continue to travel in their original direction until they meet land. Small waves form in small areas of water with islands and other landmasses but large waves form in open stretches of sea where the wind blows steadily and strongly. When waves meet other waves coming from different directions, interference between the two can produce broken, irregular seas.']\n",
      "\t181.498\t['Tidal force', 'Tidal force is caused by gravity and makes tides happen. This is because the gravitational field changes across the middle of a body (the diameter).']\n",
      "CPU times: user 5.91 ms, sys: 0 ns, total: 5.91 ms\n",
      "Wall time: 5.65 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "search_raft_flat(query = \"What is creating tides?\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a59d7b32-0832-4c3a-864e-aeb2e6e7fe1f",
   "metadata": {},
   "source": [
    "## Using CAGRA: GPU graph-based Vector Search\n",
    "\n",
    "CAGRA is a graph-based nearest neighbors implementation with state-of-the art query performance for both small- and large-batch sized vector searches. \n",
    "\n",
    "CAGRA follows the same two-step APIs as IVF-FLAT and IVF-PQ in RAFT. First we build the index:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "50df1f43-c580-4019-949a-06bdc7185536",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pylibraft.neighbors import cagra"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "091cde52-4652-4230-af2b-75c35357f833",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 35.3 s, sys: 4.5 s, total: 39.8 s\n",
      "Wall time: 2.16 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "params = cagra.IndexParams(intermediate_graph_degree=32, graph_degree=16, build_algo=\"nn_descent\")\n",
    "cagra_index = cagra.build(params, corpus_embeddings)\n",
    "search_params = cagra.SearchParams(algo=\"multi_cta\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "df229e21-f6b6-4d6c-ad54-2724f8738934",
   "metadata": {},
   "outputs": [],
   "source": [
    "def search_raft_cagra(query, top_k = 5):\n",
    "    # Encode the query using the bi-encoder and find potentially relevant passages\n",
    "    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)\n",
    "\n",
    "    start_time = time.time()\n",
    "    hits = cagra.search(search_params, cagra_index, question_embedding[None], top_k)\n",
    "    end_time = time.time()\n",
    "\n",
    "    # Output of top-k hits\n",
    "    print(\"Results (after {:.3f} seconds):\".format(end_time - start_time))\n",
    "    print(\"Input question:\", query)\n",
    "    for k in range(top_k):\n",
    "        print(\"\\t{:.3f}\\t{}\".format(hits[0][0, k], passages[hits[1][0, k]]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "b5e862fd-b7e5-4423-8fbf-36918f02c8f3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Results (after 0.005 seconds):\n",
      "Input question: Who was Grace Hopper?\n",
      "\t181.649\t['Grace Hopper', 'Hopper was born in New York, USA. Hopper graduated from Vassar College in 1928 and Yale University in 1934 with a Ph.D degree in mathematics. She joined the US Navy during the World War II in 1943. She worked on computers in the Navy for 43 years. She then worked in other private industry companies after 1949. She retired from the Navy in 1986 and died on January 1, 1992.']\n",
      "\t192.946\t['Leona Helmsley', 'Leona Helmsley (July 4, 1920 – August 20, 2007) was an American businesswoman. She was known for having a flamboyant personality. She had a reputation for tyrannical behavior; she was nicknamed the Queen of Mean.']\n",
      "\t194.951\t['Grace Hopper', 'Grace Murray Hopper (December 9 1906 – January 1 1992) was an American computer scientist and United States Navy officer.']\n",
      "\t202.192\t['Nellie Bly', 'Elizabeth Cochrane Seaman (born Elizabeth Jane Cochran; May 5, 1864 – January 27, 1922), better known by her pen name Nellie Bly, was an American journalist, novelist and inventor. She was a newspaper reporter, who worked at various jobs for exposing poor working conditions. Nellie Bly, also, fought for women\\'s right and was known for investigative reporting. She best known for her record-breaking trip around the world in 72 days, inspired by the adventure novel \"Around the World in Eighty Days\" by Jules Verne. In the 1880s, she went undercover as a mentally ill patient in a psychiatric hospital for ten days, with the report being made public in a book called \"\"Ten Days in a Mad-House\"\". She was added to the National Women\\'s Hall of Fame in 1998.']\n",
      "\t205.038\t['Abbie Hoffman', 'Abbot Howard \"Abbie\" Hoffman (November 30, 1936 – April 12, 1989) was an American social and political activist.']\n",
      "CPU times: user 4.18 ms, sys: 3.88 ms, total: 8.07 ms\n",
      "Wall time: 9.97 ms\n"
     ]
    }
   ],
   "source": [
    "%%time \n",
    "search_raft_cagra(query=\"Who was Grace Hopper?\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
